{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.mWDN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# mWDN"
   ]
  },
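  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Multilevel Wavelet Decomposition Network (mWDN)\n",
    "\n",
    "This is an unofficial PyTorch implementation created by Ignacio Oguiza (oguiza@timeseriesAI.co), based on:\n",
    "\n",
    "Wang, J., Wang, Z., Li, J., & Wu, J. (2018, July). Multilevel wavelet decomposition network for interpretable time series analysis. In *Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining* (pp. 2437-2446)."
   ]
  },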
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from tsai.models.layers import *\n",
    "from tsai.models.InceptionTimePlus import *\n",
    "from tsai.models.utils import build_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# This is an unofficial PyTorch implementation by Ignacio Oguiza - oguiza@timeseriesAI.co based on:\n",
    "\n",
    "# Wang, J., Wang, Z., Li, J., & Wu, J. (2018, July). Multilevel wavelet decomposition network for interpretable time series analysis. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (pp. 2437-2446).\n",
    "# No official implementation found\n",
    "\n",
    "\n",
    "class WaveBlock(Module):\n",
    "    \"Single wavelet decomposition level that splits a series into half-length high- and low-frequency sub-series\"\n",
    "\n",
    "    def __init__(self, c_in, c_out, seq_len, wavelet=None):\n",
    "        \"`seq_len`: size of the input's last dim. `wavelet`: optional pywt wavelet name (c_in and c_out are not used here)\"\n",
    "        if wavelet is None:\n",
    "            # Default high-/low-pass decomposition filters, so pywt is not needed.\n",
    "            # NOTE(review): these look like the db4 coefficients rounded to 4 decimals - confirm.\n",
    "            self.h_filter = [-0.2304,0.7148,-0.6309,-0.028,0.187,0.0308,-0.0329,-0.0106]\n",
    "            self.l_filter = [-0.0106,0.0329,0.0308,-0.187,-0.028,0.6309,0.7148,0.2304]\n",
    "        else:\n",
    "            try: \n",
    "                import pywt\n",
    "            except ImportError as e: \n",
    "                # pywt is the import name; the package is distributed as PyWavelets.\n",
    "                raise ImportError(\"You need to either install pywt (pip install PyWavelets) to run mWDN or set wavelet=None\") from e\n",
    "            w = pywt.Wavelet(wavelet)\n",
    "            self.h_filter = w.dec_hi\n",
    "            self.l_filter = w.dec_lo\n",
    "\n",
    "        # Trainable seq_len x seq_len projections whose weights start from the filter coefficients\n",
    "        self.mWDN_H = nn.Linear(seq_len,seq_len)\n",
    "        self.mWDN_L = nn.Linear(seq_len,seq_len)\n",
    "        self.mWDN_H.weight = nn.Parameter(self.create_W(seq_len,False))\n",
    "        self.mWDN_L.weight = nn.Parameter(self.create_W(seq_len,True))\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "        self.pool = nn.AvgPool1d(2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"Return `(lp_out, all_out)`: the pooled low-frequency branch and the high/low branches concatenated on the last dim\"\n",
    "        hp_out = self.pool(self.sigmoid(self.mWDN_H(x)))\n",
    "        lp_out = self.pool(self.sigmoid(self.mWDN_L(x)))\n",
    "        return lp_out, torch.cat((hp_out, lp_out), dim=-1)\n",
    "    \n",
    "    def create_W(self, P, is_l, is_comp=False):\n",
    "        \"Build the initial `P` x `P` weight: row i holds the filter coefficients starting at column i\"\n",
    "        filter_list = self.l_filter if is_l else self.h_filter\n",
    "        list_len = len(filter_list)\n",
    "        # Background values: zeros when `is_comp`, otherwise noise scaled below the smallest coefficient magnitude.\n",
    "        if is_comp: weight_np = np.zeros((P, P))\n",
    "        else: weight_np = np.random.randn(P, P) * 0.1 * np.min(np.abs(filter_list))\n",
    "        for i in range(P):\n",
    "            n = min(list_len, P - i)  # the filter is truncated where it would run past the last column\n",
    "            weight_np[i, i:i + n] = filter_list[:n]\n",
    "        return tensor(weight_np)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _mwdn_out_len(seq_len, levels):\n",
    "    \"Last-dim length obtained by concatenating the outputs of `levels` stacked `WaveBlock`s\"\n",
    "    # Level i receives seq_len // 2**i steps and returns its high- and low-frequency\n",
    "    # halves side by side; AvgPool1d(2) floors odd lengths, hence the integer division.\n",
    "    return sum(2 * (seq_len // 2 ** (i + 1)) for i in range(levels))\n",
    "\n",
    "\n",
    "class mWDN(Module):\n",
    "    \"Wavelet decomposition over `levels` levels followed by a `base_arch` model applied to the concatenated sub-series\"\n",
    "    def __init__(self, c_in, c_out, seq_len, levels=3, wavelet=None, base_arch=InceptionTimePlus, **kwargs):\n",
    "        self.levels=levels\n",
    "        self.blocks = nn.ModuleList()\n",
    "        for i in range(levels): self.blocks.append(WaveBlock(c_in, c_out, seq_len // 2 ** i, wavelet=wavelet))\n",
    "        # The base model consumes the concatenation of all levels, so it is sized for that length rather than seq_len.\n",
    "        self._model = build_model(base_arch, c_in, c_out, seq_len=_mwdn_out_len(seq_len, levels), **kwargs)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"Decompose `x` level by level and feed the concatenation of every level's output to the base model\"\n",
    "        outs = []\n",
    "        for i in range(self.levels):\n",
    "            # each block returns the low-frequency series (input to the next level) and its own output\n",
    "            x, out_ = self.blocks[i](x)\n",
    "            outs.append(out_)\n",
    "        return self._model(torch.cat(outs, dim=-1))\n",
    "    \n",
    "\n",
    "class mWDNBlocks(Module):\n",
    "    \"Stack of `levels` `WaveBlock`s returning the outputs of all levels concatenated on the last dim\"\n",
    "    def __init__(self, c_in, c_out, seq_len, levels=3, wavelet=None):\n",
    "        self.levels=levels\n",
    "        self.blocks = nn.ModuleList()\n",
    "        for i in range(levels): self.blocks.append(WaveBlock(c_in, c_out, seq_len // 2 ** i, wavelet=wavelet))\n",
    "\n",
    "    def forward(self, x):\n",
    "        outs = []\n",
    "        for i in range(self.levels):\n",
    "            # the low-frequency output of one level is the input of the next one\n",
    "            x, out_ = self.blocks[i](x)\n",
    "            outs.append(out_)\n",
    "        return torch.cat(outs, dim=-1)\n",
    "    \n",
    "\n",
    "class mWDNPlus(nn.Sequential):\n",
    "    \"`mWDNBlocks` prepended to the backbone of a base model, keeping that model's head\"\n",
    "    def __init__(self, c_in, c_out, seq_len, d=None, levels=3, wavelet=None, base_model=None, base_arch=InceptionTimePlus, **kwargs):\n",
    "\n",
    "        # The backbone receives the concatenated sub-series, so the base model is built for that length.\n",
    "        # A `base_model` passed by the caller must accept inputs of length `_mwdn_out_len(seq_len, levels)`.\n",
    "        if base_model is None:\n",
    "            base_model = build_model(base_arch, c_in, c_out, d=d, seq_len=_mwdn_out_len(seq_len, levels), **kwargs)\n",
    "        blocks = mWDNBlocks(c_in, c_out, seq_len, levels=levels, wavelet=wavelet)\n",
    "        backbone = nn.Sequential(blocks, base_model.backbone)\n",
    "        super().__init__(OrderedDict([('backbone', backbone), ('head', base_model.head)]))\n",
    "        self.head_nf = base_model.head_nf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.models.TSTPlus import TSTPlus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: mWDN and mWDNPlus both map a (bs, c_in, seq_len) batch to (bs, c_out) outputs\n",
    "bs, c_in, seq_len, c_out = 16, 3, 12, 2\n",
    "xb = torch.rand(bs, c_in, seq_len).to(default_device())\n",
    "\n",
    "# plain mWDN with the default base architecture\n",
    "test_eq(mWDN(c_in, c_out, seq_len).to(xb.device)(xb).shape, [bs, c_out])\n",
    "\n",
    "# mWDNPlus: default base architecture first, then TSTPlus (`model` ends up being the TSTPlus variant)\n",
    "for arch_kwargs in ({}, dict(base_arch=TSTPlus)):\n",
    "    model = mWDNPlus(c_in, c_out, seq_len, fc_dropout=.5, **arch_kwargs)\n",
    "    test_eq(model.to(xb.device)(xb).shape, [bs, c_out])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Sequential(\n",
       "   (0): GELU(approximate='none')\n",
       "   (1): fastai.layers.Flatten(full=False)\n",
       "   (2): LinBnDrop(\n",
       "     (0): Dropout(p=0.5, inplace=False)\n",
       "     (1): Linear(in_features=1536, out_features=2, bias=True)\n",
       "   )\n",
       " ),\n",
       " 128)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.head, model.head_nf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mWDNPlus with an integer `d`: the output is expected to have shape (bs, d, c_out)\n",
    "bs, c_in, seq_len, c_out = 16, 3, 12, 2\n",
    "d = 10\n",
    "xb = torch.rand(bs, c_in, seq_len).to(default_device())\n",
    "model = mWDNPlus(c_in, c_out, seq_len, d=d).to(xb.device)\n",
    "test_eq(model(xb).shape, [bs, d, c_out])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mWDNPlus with a tuple `d`: the output is expected to have shape (bs, *d, c_out)\n",
    "bs, c_in, seq_len, c_out = 16, 3, 12, 2\n",
    "d = (5, 2)\n",
    "xb = torch.rand(bs, c_in, seq_len).to(default_device())\n",
    "model = mWDNPlus(c_in, c_out, seq_len, d=d).to(xb.device)\n",
    "test_eq(model(xb).shape, [bs, *d, c_out])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/052_models.mWDN.ipynb saved at 2023-07-04 17:07:00\n",
      "Correct notebook to script conversion! 😃\n",
      "Tuesday 04/07/23 17:07:03 CEST\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
